import pickle
import os, sys
import cv2
import math
import plyfile
import numpy as np
from transformers import CLIPModel, AutoProcessor
from tqdm import tqdm
import time
from functools import partial
from functools import lru_cache
import glob
from PIL import Image
import json
import open3d as o3d
import viser

# sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from fusion_util import make_intrinsic, adjust_intrinsic, PointCloudToImageMapper
from scenefun3d.utils.box_utils import box3d_iou, construct_bbox_corners
from scenefun3d.utils.pc_util import point_cloud_to_bbox

# Per-scene lookup used by process_scan: scene_id -> {str(0-indexed object id): flag},
# where a flag of 0 puts the sample in the "unique" split, anything else in "multiple".
unique_multiple_lookup_file = 'Projects/TimeZero/dataset/scanrefer/unique_multiple_lookup.json'
# Use a context manager so the file handle is closed right after parsing.
with open(unique_multiple_lookup_file) as _lookup_fp:
    unique_multiple_lookup = json.load(_lookup_fp)


def read_aggregation(filename):
    """Parse a ScanNet ``*.aggregation.json`` instance-annotation file.

    Args:
        filename: path to the aggregation JSON file. It is expected to contain
            a ``segGroups`` list whose entries have ``objectId``, ``label`` and
            ``segments`` keys.

    Returns:
        A tuple ``(object_id_to_segs, label_to_segs)``:
        - ``object_id_to_segs`` maps the 1-indexed instance id
          (``objectId + 1``) to that instance's list of segment ids.
        - ``label_to_segs`` maps each label to the concatenated segment ids of
          every instance with that label, in file order.
    """
    object_id_to_segs = {}
    label_to_segs = {}
    with open(filename) as f:
        data = json.load(f)
    for group in data['segGroups']:
        object_id = group['objectId'] + 1     # instance ids should be 1-indexed
        segs = group['segments']
        object_id_to_segs[object_id] = segs
        # Accumulate into a fresh list per label. Storing `segs` itself and
        # later calling .extend() on it would also grow the list held in
        # object_id_to_segs for the first instance with this label.
        label_to_segs.setdefault(group['label'], []).extend(segs)
    return object_id_to_segs, label_to_segs


def read_segmentation(filename):
    """Parse a ScanNet ``*.segs.json`` over-segmentation file.

    Args:
        filename: path to a JSON file containing a ``segIndices`` list that
            holds one segment id per vertex.

    Returns:
        A tuple ``(seg_to_verts, num_verts)``. ``seg_to_verts`` maps each
        segment id to the ascending list of vertex indices assigned to it,
        with keys in order of first appearance. ``num_verts`` is the length
        of ``segIndices``.
    """
    with open(filename) as handle:
        seg_indices = json.load(handle)['segIndices']

    segment_members = {}
    for vertex_idx, segment_id in enumerate(seg_indices):
        segment_members.setdefault(segment_id, []).append(vertex_idx)
    return segment_members, len(seg_indices)


def export(root_path, scene_id):
    """Return the instance id of every mesh vertex in one scan.

    Combines ``<scene_id>.aggregation.json`` (instance -> segments) and
    ``<scene_id>_vh_clean_2.0.010000.segs.json`` (segment -> vertices), both
    read from ``<root_path>/scans/<scene_id>/``.

    Returns:
        A uint32 numpy array with one entry per vertex. The entry is 0 for
        vertices not covered by any instance, and the 1-indexed instance id
        from ``read_aggregation`` otherwise.
    """
    scan_dir = os.path.join(root_path, 'scans', scene_id)
    aggregation_path = os.path.join(scan_dir, f'{scene_id}.aggregation.json')
    segmentation_path = os.path.join(scan_dir, f'{scene_id}_vh_clean_2.0.010000.segs.json')

    instance_segments = read_aggregation(aggregation_path)[0]
    segment_vertices, vertex_count = read_segmentation(segmentation_path)

    vertex_instance = np.zeros(vertex_count, dtype=np.uint32)     # 0: unannotated
    for instance_id, segment_ids in instance_segments.items():
        for segment_id in segment_ids:
            # Instances are visited in insertion order; a later instance
            # overwrites an earlier one on any vertex they share.
            vertex_instance[segment_vertices[segment_id]] = instance_id
    return vertex_instance


@lru_cache(maxsize=8)
def _load_scene(root_path, scene_id):
    """Load mesh vertex positions and per-vertex instance ids for one scene.

    The result is cached so that predictions sharing a scene do not re-parse
    the PLY mesh and annotation JSON files on every call. The returned arrays
    are marked read-only so callers cannot corrupt the shared cached copies.

    Returns:
        A tuple ``(coords, ins_labels)``. ``coords`` has shape (N, 3) and holds
        the first three vertex properties. ``ins_labels`` has shape (N,) and
        is produced by ``export`` (0 means unannotated).
    """
    ply = plyfile.PlyData().read(os.path.join(root_path, 'scans', scene_id, f'{scene_id}_vh_clean_2.ply'))
    vertices = np.array([list(x) for x in ply.elements[0]])
    coords = np.ascontiguousarray(vertices[:, :3])     # (N, 3)
    ins_labels = export(root_path, scene_id)     # (N, )
    coords.setflags(write=False)
    ins_labels.setflags(write=False)
    return coords, ins_labels


def process_scan(root_path, image_path, mask_pred, obj_id, thresh, acc, point2img_mapper, total):
    """Lift one 2D mask prediction into 3D and score it against the GT object.

    The scene's mesh vertices are projected into the frame named by
    ``image_path``, using that frame's depth map and camera pose. The vertices
    that land on non-zero pixels of ``mask_pred`` form the predicted point
    set. Its box (via ``point_cloud_to_bbox``) is compared by 3D IoU with the
    box of the ground-truth instance ``obj_id``.

    Args:
        root_path: dataset root containing ``scans/``, ``depth/`` and ``poses/``.
        image_path: '/'-separated path whose last two components are
            ``<scene_id>/<frame file>``. The frame's stem selects
            ``depth/<scene_id>/<stem>.png`` and ``poses/<scene_id>/<stem>.txt``.
        mask_pred: 2D predicted mask indexed as [row, col]. Its shape defines
            the image size used for projection.
        obj_id: 1-indexed instance id of the target object (as in ``export``).
        thresh: numpy array of IoU thresholds.
        acc: dict of per-split hit counters (numpy arrays), updated in place.
        point2img_mapper: mapper whose ``compute_mapping`` projects points.
        total: dict of per-split sample counts, updated in place.

    Raises:
        FileNotFoundError: if the depth image cannot be read.
    """
    parts = image_path.split('/')
    scene_id = parts[-2]
    frame_stem = os.path.splitext(parts[-1])[0]

    coords, ins_labels = _load_scene(root_path, scene_id)
    gt_pcd = coords[ins_labels == obj_id]

    depth_file = os.path.join(root_path, 'depth', scene_id, frame_stem + '.png')
    depth = cv2.imread(depth_file, cv2.IMREAD_UNCHANGED)
    if depth is None:
        # cv2.imread returns None instead of raising for missing/unreadable files.
        raise FileNotFoundError(f'Could not read depth image: {depth_file}')
    depth = depth / 1000.0     # convert to meter
    pose = np.loadtxt(os.path.join(root_path, 'poses', scene_id, frame_stem + '.txt'))

    image_dim = (mask_pred.shape[1], mask_pred.shape[0])     # (width, height)
    link = point2img_mapper.compute_mapping(pose, coords, depth, image_dim)

    # link[:, 0] / link[:, 1] index mask rows / cols; link[:, -1] is presumably
    # a per-point visibility flag (TODO confirm against fusion_util). Zero it
    # wherever the predicted mask is zero, then keep the surviving vertices.
    link[:, -1] = link[:, -1] * mask_pred[link[:, 0], link[:, 1]]
    pred_pcd = coords[link[:, -1] != 0]

    if pred_pcd.shape[0] == 0:
        # No vertex hit the mask: fall back to an all-zero box.
        pred_box = np.zeros((6,), dtype=np.float32)
    else:
        pred_box = point_cloud_to_bbox(pred_pcd)
    gt_box = point_cloud_to_bbox(gt_pcd)

    pred_box_corners = construct_bbox_corners(pred_box[:3], pred_box[3:])
    gt_box_corners = construct_bbox_corners(gt_box[:3], gt_box[3:])
    iou = box3d_iou(pred_box_corners, gt_box_corners)

    # One hit flag per threshold.
    hits = (iou >= thresh).astype(np.int32)
    acc['all'] += hits

    # The lookup is keyed by the 0-indexed object id as a string.
    split = 'unique' if unique_multiple_lookup[scene_id][str(obj_id - 1)] == 0 else 'multiple'
    acc[split] += hits
    total[split] += 1
    total['all'] += 1


def main():
    """Evaluate 2D mask predictions lifted to 3D and print grounding accuracy.

    The predictions are loaded from ``mask_val.pkl``; each is a dict with
    ``image``, ``mask_pred`` and ``obj_id`` keys. Every prediction is scored
    with ``process_scan``. Acc@IoU for each threshold is then printed for the
    ``all``, ``unique`` and ``multiple`` splits.
    """
    # Image size and pinhole intrinsics used for projection (values presumably
    # match the dataset's depth camera -- confirm before changing).
    img_dim = (640, 480)
    fx = 577.870605
    fy = 577.870605
    mx = 319.5
    my = 239.5
    intrinsics = make_intrinsic(fx=fx, fy=fy, mx=mx, my=my)
    intrinsics = adjust_intrinsic(intrinsics, intrinsic_image_dim=[640, 480], image_dim=img_dim)

    root_path = '/data2/datasets/scannet/'

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open("Projects/TimeZero/dataset/scanrefer/mask_val.pkl", 'rb') as f:
        data = pickle.load(f)

    visibility_threshold = 0.25     # threshold for the visibility check
    cut_num_pixel_boundary = 1     # do not use the features on the image boundary

    thresh = np.array([0.25, 0.5])
    acc = {"all": np.array([0, 0]), "unique": np.array([0, 0]), "multiple": np.array([0, 0])}
    # Exact sample counts; empty splits are handled at division time instead of
    # biasing every accuracy with an epsilon in the denominator.
    total = {"all": 0, "unique": 0, "multiple": 0}

    # calculate image pixel-3D points correspondences
    point2img_mapper = PointCloudToImageMapper(image_dim=img_dim,
                                               intrinsics=intrinsics,
                                               visibility_threshold=visibility_threshold,
                                               cut_bound=cut_num_pixel_boundary)

    for item in tqdm(data):
        image_path = item['image']
        mask_pred = item['mask_pred']
        obj_id = item['obj_id']
        process_scan(root_path, image_path, mask_pred, obj_id, thresh, acc, point2img_mapper, total)

    for key, value in acc.items():
        count = total[key]
        for th, r in zip(thresh, value):
            print(f'{key} Acc@{th}:', r / count if count else 0.0)


if __name__ == '__main__':
    main()
